import os
import pandas as pd
import yaml
import pickle

import pickle as pkl
import numpy as np
import networkx as nx
from pygenstability import pygenstability as pgs
from geometric_clustering import cluster_signed_modularity, load_curvature
from pygenstability.pygenstability import _evaluate_VI


from cdlib import algorithms


def _convert_coms(communities, n_nodes):
    """Convert a list of communities into a flat per-node label array.

    Args:
        communities: iterable of communities, each an iterable (list, tuple,
            set, array, ...) of integer node ids in ``range(n_nodes)``, e.g.
            ``res.communities`` of a cdlib clustering result.
        n_nodes: total number of nodes in the graph.

    Returns:
        ``np.ndarray`` of ints with length ``n_nodes``. Entry ``i`` is the index
        (position in ``communities``) of the community containing node ``i``.
        Nodes that belong to no community keep label 0. A node listed in
        several communities receives the label of the last one.
    """
    labels = np.zeros(n_nodes, dtype=int)
    for com, nodes in enumerate(communities):
        # Materialize the community as a 1-D integer index array. Indexing with
        # a raw tuple would be interpreted by numpy as a multi-dimensional index,
        # and a set cannot be used as an index at all.
        labels[np.fromiter(nodes, dtype=int)] = com
    return labels


if __name__ == "__main__":
    # Run (almost) every cdlib community-detection algorithm on the Hox
    # gene-expression graph and score each partition against an annotation
    # with the variation of information (VI). The per-algorithm community
    # count and VI are written to <m>/results_others.csv.
    m = "jaccard"  # selects the input graph file and the output directory

    # nx.read_gpickle was removed in networkx 3.0; it only wrapped pickle.load,
    # so load the file directly. NOTE: unpickling can execute code -- only use
    # with trusted local data files.
    graph_path = os.path.join("data", "hox_gene_expression_" + m + ".gpickle")
    with open(graph_path, "rb") as f:
        graph = pkl.load(f)
    print(len(graph))
    # Relabel nodes to 0..n-1 so node ids can directly index the label arrays.
    # NOTE(review): assumes graph node order matches the metadata row order -- confirm.
    graph = nx.convert_node_labels_to_integers(graph)

    with open(os.path.join("data", "hox_gene_expression.pkl"), "rb") as f:
        # Drop the last row. NOTE(review): presumably a row with no graph node
        # (so rows align with nodes) -- confirm.
        metadata = pkl.load(f)[:-1]
    # Integer-encode the categorical annotations (np.unique inverse indices).
    ground_truth_1 = np.unique(metadata["Neurotransmitter"].to_list(), return_inverse=True)[1]
    ground_truth_2 = np.unique(metadata["Neuron Class"].to_list(), return_inverse=True)[1]
    ground_truth = ground_truth_2  # annotation the partitions are scored against

    os.chdir(m)
    table = pd.DataFrame()
    # Algorithms excluded from the sweep.
    # NOTE(review): presumably they fail, are too slow, or return overlapping
    # covers on this graph -- confirm.
    forbidden_alg = {
        "belief",
        "ego_networks",
        "percomvc",
        "egonet_splitter",
        "nnsed",
        "edmot",
        "ga",
        "der",
        "big_clam",
        "cmp",
        "core_expansion",
        "leiden",
        "dcs",
        "scd",
        "danmf",
        "gemsec",
    }

    def _record(name, res):
        """Store community count and VI vs ``ground_truth`` for one cdlib result; return the VI."""
        community_id = _convert_coms(res.communities, len(graph))
        table.loc[name, "n_comms"] = len(res.communities)
        vi = float(_evaluate_VI((0, 1), [community_id, ground_truth]))
        table.loc[name, "vi"] = vi
        return vi

    for alg in dir(algorithms):
        # dir() also lists dunder attributes and submodules; only public
        # callables are candidate algorithms.
        if alg.startswith("_") or alg in forbidden_alg:
            continue
        method = getattr(algorithms, alg)
        if not callable(method):
            continue
        print(alg)
        try:
            vi = _record(alg, method(graph))
        except Exception as ex:
            # Best-effort sweep: many algorithms need extra arguments or reject
            # this graph; report and move on.
            print(alg, "lkljk", ex)
            continue
        print(alg, vi)

    # Louvain at lower resolutions gives coarser partitions for comparison.
    for suffix, resolution in (("01", 0.1), ("001", 0.01)):
        _record("louvain_" + suffix, algorithms.louvain(graph, resolution=resolution))

    table.reset_index().to_csv("results_others.csv", index=False)
